{
 "cells": [
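  {
   "cell_type": "markdown",
   "id": "8214df3c",
   "metadata": {},
   "source": [
    "## 3.2 Linear Regression Implementation from Scratch\n",
    "In this section, we implement the entire method from scratch, including the data pipeline, the model, the loss function, and the minibatch stochastic gradient descent optimizer."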
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8b18252c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import torch"
   ]
  },
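  {
   "cell_type": "markdown",
   "id": "1f82c8a4",
   "metadata": {},
   "source": [
    "### 3.2.1 Generating the Dataset\n",
    "To keep things simple, we construct a synthetic dataset from a linear model with additive noise. Our task is to recover the parameters of this model from the finite set of examples in the dataset. We keep the data low-dimensional so that it is easy to visualize. In the code below, we generate a dataset of 1000 examples, each with 2 features sampled from a standard normal distribution. The synthetic feature set is therefore a matrix $X \\in \\mathbb{R}^{1000 \\times 2}$. We use the true parameters $w = [2, -3.4]^\\top$ and $b = 4.2$ and generate the labels as $y = Xw + b + \\epsilon$, where the noise $\\epsilon$ is drawn from a normal distribution with mean 0 and standard deviation 0.01."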
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e7cafa56",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def synthetic_data(w, b, num_examples):\n",
    "    '''Generate y = Xw + b + noise.'''\n",
    "    # Features: one row per example, one column per weight, drawn from N(0, 1)\n",
    "    X = torch.normal(0, 1, (num_examples, len(w)))\n",
    "    y = torch.matmul(X, w) + b\n",
    "    # Additive Gaussian noise with standard deviation 0.01\n",
    "    y += torch.normal(0, 0.01, y.shape)\n",
    "    # Return the labels as a column vector of shape (num_examples, 1)\n",
    "    return X, y.reshape((-1, 1))\n",
    "\n",
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "features, labels = synthetic_data(true_w, true_b, 1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f9d053a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "Note that each row of `features` is a two-dimensional example, and each row of `labels` is a one-dimensional label value (a scalar)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "423e1714",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features: tensor([-0.3234, -0.0593]) \n",
      "label: tensor([3.7554])\n"
     ]
    }
   ],
   "source": [
    "print('features:', features[0],'\\nlabel:', labels[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26b1755f",
   "metadata": {},
   "source": [
    "A scatter plot of the second feature `features[:, 1]` against `labels` makes the linear relationship between the two clearly visible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d9a9d754",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA1L0lEQVR4nO2df3SU13nnvxeBUiRcjISs8EO2QJLlyDmONpnYRIHYWNDGXQ5u9izeZLcbnbQbwjkbR+tme1o3PsmmTTfd7bpe2rOnRW2T4rOpU9M2iY8TtzGEAIJiIzuEYsVYGiQQthdGI0EtjepB4u4f79xXd+7cd+ad3zMv3885PoNm3rnvfV9Z3/u8z31+CCklCCGEBJMl5Z4AIYSQ4kGRJ4SQAEORJ4SQAEORJ4SQAEORJ4SQALO03BPQWb16tWxtbS33NAghpKp45ZVXJqWUTbbPKkrkW1tbMTQ0VO5pEEJIVSGEuOD1Gd01hBASYCjyhBASYCjyhBASYCjyhBASYCjyhBASYCjyhBASYCjyhBASYCjyPpmajWPfkTCmZuPlngohhPiGIu+TA0MT+PoLr+PA0ES5p0IIIb6pqIzXSmZXqCXplRBCqgGKvE8a6mvxufvbyj0NQgjJCrprCCEkwFDk84QbsoSQSoYinyfckCWEVDL0yecJN2QJIZUMRT5PuCFLCKlk6K4hhJAAQ5EnhJAAQ5EnhJAAQ5EnhJAAUxCRF0J8QwhxRQhxVnvvvwkh3hRCnE7890uFOBchhBD/FMqS/0sAH7e8/5SUsjvx3w8KdC5CCCE+KYjISymPApgqxFhBgxmxhJByUmyf/OeFEGcS7pxVtgOEELuFEENCiKFIJFLk6ZRedJkRSwgpJ8UU+T8B0AagG8DbAJ60HSSlHJBShqSUoaampiJOx6HUorsr1ILHH7qLGbGEkLJQNJGXUl6WUi5IKW8A+DMA9xbrXNlQatFtqK/FrlALDgxN5Pz0QJcPISRXiibyQog12o+fAHDW69hSosoQNNTXluyc+T490OVDCMmVgtSuEUI8A+ABAKuFEJcAfAXAA0KIbgASwDiAzxXiXNVIvkXMWASNEJIrQkpZ7jm4hEIhOTQ0VO5pEEJIVSGEeEVKGbJ9xoxXQggJMBR5QggJMBT5HGHECyGkGqDI54jfiJdiLgbZjM1FiZCbE3aGyhG/ES9qMQBQ8A5S2YxdzHkQQioXinyO+G37V6jwx6nZOA4MTWBXqMWN8c9mbIZhEnJzwhBK2AU02+/vPzEGQKCvp7UoiVb7joTx9Rdex+MP3UVLnBCSRLoQSlryyN+VcWBoAnsPjQIA6mpriiLCtMQJIblAkUeygOZi1e8KtSAWnwcgiibCft1DhBCiQ5FHsoAqtwjg36pvqK/FY9s7izY/QgjJFYq8Ad0ihJAgQZE3oFuEEBIkmAxVQjIlJDFhiRBSaCjyBcCvOGfKks2lbjwXBkJIOuiuKQB+QzAz+ftz2Q8oRyZrvnkFhJDSQZEvAH7FOZO/P5f9gHJsFBdqYeFiQUjxocgXgHJu1pbj3IVaWGyLBYWfkMJCn/xNTK7+fK8+udmOZ2uqzn62hBQWinyZ8COIxd5ULbSgZjuebbGwCT8hJHforikTfvzahfB9p3N/+HW7+HWhFMKNwzwFQgoLRb5EmEKZThDVsdu6mj2P8Uu6hcKvoPpdbHIRaPrgCSkuFPk88CtQU7NxfPHZ0zh8LgIArotiV6jF+v1ChkVu62rGyfNRd8HIhWwt9GyEe/+Jcew9NIJYfAGPbb8z5zkSQuzQJ58Hpg/ay4d+YGgCh89FsLWzydcmYyH90geHL+PwuQgODl9OmZ9fn7/XRqsX6rr2nxj3Mb40XgkhhYSWfB6YFq5pgdvcLuYmo/6qKKRfWj+HOb9iJVKpp4e5+AL2HhpJO35fzwbU1S6tqI1WupBIkKDI54EpxplEP9P3M5GL+KhzTM3GEYvPo7+3I2WehRTY
qdk4vvb8MA6fi+Ce9SszPpFU4kYr++GSIEGRLyCZRD9fvJKH9p8YByDR17PBU/xV96rHH7rLPaYYAqu7ptLNpxTkapGz3DQJEoER+Up8xM4kotnO2SY+jng7LpG62qVJ59PHL4VwOU8LC+jvbS+7wAO5W+SV+HRBSK4EZuO1GjMlC5U81N/bgf7e9hQBV+N/8dnT7rEHhiZSNkKzSbpKd6xacOpql5ZF4M25MbGKkAJZ8kKIbwDYAeCKlPL9ifcaAPw1gFYA4wAekVJOF+J8NqrxEbtQyUNeoYe7Qi04+kYEh89FsP/EGOpql1otW7UYnDwfxZOPdKcV6HTWcSGuJ58nMtt10CInNzuFsuT/EsDHjfd+C8AhKWUHgEOJn4tGtmF+xcaPdWybcyFLGTTU1yLU2pD4SXhatrtCLdja2YTD5yIZnyr0McKRGXzmmy8jHJnxvJ5syeeJbFtXM9qa6n1dByE3CwWx5KWUR4UQrcbbDwN4IPHv/QB+DOA3C3G+akAl+URn42hMJD7ZxM+0XP34kbOxdvt6WlFXW+Me6xXl8+Qj3e6Y6dDHWEzwGsY3P3Nv1nOzkc/TwMHhywhHZlPyEfxSifs6hORLMTdem6WUbwOAlPJtIcRttoOEELsB7AaA22+/vYjTKTVOcs9rb17D8XAUgF20TVH3I3L6d3aFWpKia9TnSqj8uixycW08+mAHLk7F8OiDHe57fjJY04lpPi4W/d7lItIMnSRBpOzRNVLKAQADABAKhQKT9qiSfLZ1NePg8GXf3aDSiZwtucqMrgFQMqE6NT6FcGQWp8an8ME7ViXezZzBmq2Y+rWw9ZyAfUfCDJ0kBMUV+ctCiDUJK34NgCtFPFfFoYt12/0rfB2XCZs47gq1IBZfACCTxEn9WxdINYZf11Gmz22iuLN7Hc5cuoad3es8r8OZ8zxi8QVMzcZT9iTMOWS7KKjjY/GFJFdVJrhRS4JIMUX+OQB9AH4/8fq9Ip6rIOTqky2VL9fb0pQAhPU7ukACSCt+i+I475YaSFc4zSaKqlbOpo2XPRe3hvpaN9KnrrYmKbHLLOSW/roXseUExOLzdL+Qm55ChVA+A2eTdbUQ4hKAr8AR92eFEL8G4CKAXYU4VzHJ1SdbKl+uzR2hMlkB4Mylq7hn/a1J9WJsAqmLn+PTHwMgsLN7beLzBev1+BFbr2PMJwqzxALgXchNX0y8FlTzd6DuU6XVxSGk1BQquuZTHh/1FmL8UpGrT7bUvlxz4zUWn8crF66m1IuxCaIpfvoioaxq5/OajIXTbON7uTzMJwqzxILKlt29ZQOW13r/b+kV02/7HeTqfmGUDQkSZd94rSRyFYVS+3LNKJLHtndahWnfkXDaJ4zpmFO0bPeWjViuiXqmDUx1rlh8HnsPjSI6E8fw29fQtWYl9jxgj5PX69qvqksWZmAxW1bF6+tuHPPaT56PurHw6phC/g5yfTLj4kAqEYp8FWITNNt7Xk8YujV8+FwEjz90l6/wTvP9ze2N6O/twCsXpjA4GsXgaBSNK1LnoVem3LTxsutOsc11W1czNm30jkbKJqY/V3J9MmMIJqlEKPIlplzWnlOtcgxz128AEujvbcfO7nVpBdVL7HRrektHE7768Pvxle+dRdealdaxvHztal7qfviJRgKK/+SU6/gMwSSVCEW+xJTS2rP5wRWPP3QX2ppW5BTeaVrTDfW1+L//aZP7udcma19Pa8rClul+ZLMoZtOOsRgLLUMwSSUSSJEvxh9xocYspbVnnisWn8fc9RtYvmyJW3vma88P44kdXWhrShV783PzHngJWqZNVq855hsjnyk+3txL8DNmIaCvnpSTQIp8MazlQo1ZSmvPbBb+2PbOpM9ttWd0lB/94tQQDuzp8WxvaIqXbSEzk7P0TGB1P8yNYj3ixpY4pcZTZR12dq9DLL6AofEpaykJNf/+3g48/tBd2NbVnFNmbLbQV0/KSSBFvhjWciX5W5V/HRBWF4hO
OoF5YkcXgOHEaypP7OjCxakhhCOzSe4XPXvWljxlLmReZY1tSU+x+AJi8Xl3Mdh7aASb21djcHQSAFLq4ahaOYBT1qGutgbHw1Gr/19F+OzsXou2phUZo48KRSX9v0NuPgIp8sWwlkthgft9rLfFtnuRTmDamlYkWfDm+duaVrgWvF7KIDWzdCHFTaLmaf5bCe2jD3Zg08bGFAsfkIlrczJ4+3vbMRdfSIi8TDp2W1czhsanAAD3bViFWHwBO7vXumUezHurR/i03b+iZOJLXz0pJ4EU+WrF72O9SoBSNeLT4UdglGhGZ97FwLExxOLzrmtHF/ZtXc2uUKpkJJUxu/tjG/GP4Um8PD7tWvZmOQXA2Rc4fC6CjttuQeOKWkzHlJ98AXsPjbiuFOU3f/yhu9DXswGNK97j+u2Vm+nk+ahrtd+zfmWiK1UN6mprEiUTFtsh2iJ8mCxFbgYo8hVENpalrbZMrizGva9OvCOsnz87NIFwZBYN9csS3abGceaSk2nb2liH8WjM/b5+LdOxuJsI9dennGYeQxem8OrFqzj6RgTHw1H097ajv7cDeqE1W0kCXayf2NGVFAJqHm/7t9c9y7QJbbsfQOHdPFxASKGhyFcQfsoMK6s6X5ExC3rF4vOYnr0OAG4NG4UeF9/WVI9wZBatjXU4eX4SL41Nu+8BQGtjHXZ2r026FiXMmzZexvJlTjOyZTXO693rVuJjdzYlXZcS61h8AX/64zBee2uxJr8p1noIqO5KypQsZoqpekIxN6FtoltMNw83aUmhochXCWa9Gv013/GUmDx98gIAp5LkqlBtkripuHjdZTMejbkW9XOn33KjWg4OJ1egNEsamHX2beKp18kHgLamemzravYVummGUPoJzTQ3oc1wS3NzuRCLa6aQUkIKAUW+SjAt2HytvFQxcVw0m9tXW58W9HM++Uh3SnTPY9vvTEmC0oVSWfLqs1V16a9BWfLTsTgOv34F4cgsnjv9ZkoYqBlCCaSWGPaqw6+/tjWtcBcytcCpcMv+3na8cuGqG+GTz73PZKn7KQJHSDZQ5KuEYkdo7OxeizOXrrqWrK0UMLAoOn09GzK27lOCdt+GBny0rRHbuprd946+EUGotSEpBNQUwMe234l9R8KYmJ5LjCiS5mBa/HPxGxi58g4efbADc/EFHBuZxLauZtcdpcfa2+6nfv5tXc04+kYEc9cXsHzZEgyOTvrqHZspvDVbS53uG5IvFPmAkskCNMVDb/YBeGepeiVEeblf9Hh49bl673g4mhQCuivUgujMu/jR61cQnY1jz/1tSZFEfT2tKXPYFWpBdDaO1968hp9emsZLY9O4OBXDe3/+53A8HMXXnnd87LYmJeZ9Mt1Fx8PRxKZwh1u+OZM1nSm8NdvFmu4bki8U+YCSyQLU/eRA+ixVHa+EKBV5Y56v47ZbEIsv4APrVyb595W1a2adjlyZwUtjU3hpbAqNCUHUXTROFuziU0ZDfS0a62txPBzF7i0bMTkTRzgyi973NaN26RL3ycR2fbZOVMmJWU4ZCBXxk6klIoCkmviFEGbG2JN8ocgHlEwWoLLcgWG3+YbZBcr2JGB2aYrF592MVNOdcWBoAgPHzgMAHrzrNgBwBV0Jt5l1+sSOLsTnz+LudYsVLfWnBbXp+/hDd7nf1xeqPQ+0uT56vamIl3vGqzqmqtOv5qfH3JtjJNfpGfEs3UxIOaDIVzD5bLplsgB1t4nefENhdmBS7+kWvIo86e9tx5aO1SnzNJO2TDeLEm4ASRb9tz67yXMuh89FsLm9EbH4glvSQEW+qOu2JUN53QOv7FgzQ1fN0zaG/urVoNyG1++Xm62kkFDkK5hibrplar5hLgJK0J0N01UJcU/vq1bWsD4m4Ai67ibZFWpJ+VnfvNS/t2nj5cRcRtDf257IjnXKKkRn4xi57Gy8qnllugde2bG6dZ6uMblZ6iGT71/fu8jUlMVscUhILlDkK5h8Nt38WIPprH21
CPzpkTCOvhHBuluXAwCOh6O4e91KbO1scgt9+Tm3GRGju0lsP5ubl3pDEdWjVgnmzu61qKutcbNnL07FcGBPjy9xTLcXob+XrmKlnxwGW2E2r2MzPWURkg0U+Qomn003v2WBM51/5PI7OB6O4o6GOgDA5vZGLF+2JK11u//EmJtApCz5dEJoNhUx3Ty2eX3u/rYUf/62rmb82l+eQjgyi/0nxjx739rGUng1P/eqWGnbCPbqTQsktzdMd+5itzgkNw8U+YDiFQWTrQtAZYI++mAHTo1PueOZdWKSo0xU7ZvFGjh6rDqwuPAoN40ermm6eYDF2jKPPtiBvz/7/zD89jX8+vbOJLdMW9MKPNy9LhE3LzA1G8ejf/UqjoejiMXn0dezIWOJZpsLRdW17+9tt1ro6ZqiKHRBVwujuaCY5/azwNN/TzJBka9gCrnxaroAbNEztvPp5Yg/eMcqd7x0iUR9Pa1JJYfVfEx/tVd0i61Y2Je/exbHw1GMTc66hdCW1Ywk1ZkJR2bwyoUp7N6yEX09rW6sOwDMXb+R5AYCkNI9yibm+kLU2liHues3sOf+toLUsTFFXW0Eq3r6fn7nTJYimaDIVzCF/AM2N1ptY9vcLH7Rxc7cjPQSRPNV+byVoMfnz7qRNnevW4nj4Sg+1tGEj3UAh89dQcuqOkzNxgE49+rYSASDo1G3+Fl0No6WVcsxMT3ntjxUbiBAWq5/3C13rGfhHj4XQUP9MoxHYxg4et6N31f3NdffjXn92UQGeY1BiAlFvoIp9B+wLkj2sVPdLIpMm6leJQ2A5AXKCX0cQ1/PBve96dhiQ49YfB7zN5yQxrvXrXSP2XN/Gxq1xWNieg5Pn7yAiekY7ll/K/YeGsEjofVoqF+Gvo84VvzA0fOJ616G+++8LckNpDZvk69fGq+LSWN9H2nFnx07j7vXrUy7CZvN05dtgcj2d85kKZIJinwFU8w/YNvYNjeLwiba6Z40bGKl15l5/szb6H1fMwaOnnfdSFs7mwAIvDQ2ha2dTdiT8Nv/6Y9HMfz2O/jqw3drG7NOL9fD5yK4vnAjUURsGlOz1/G73x/GwKdDiM7G8e1TFzA1ex1P/vAcvvXZTUkinHr9G1KEfzFpDPjjf/9BNNTXWnvRmslatntiw1wUWKCMFBqKPHGxCb8uYEDmUMN0Yylx/u5P3kQ4Mov3/vw1t6F2cvOPxYXmC8+8isFRx6+u6tDoVS+Vv3xLRxO++vD7sftppyft154fxj3rV+Kf55yN3rbbHN++WYQsU6OQXaEWHBuZdJukPLb9zqRImX1Hwm4Ogb5Y+bXEM0VB0edO8oUiT9KSLion2ycNJc47u9e6kTKnxqdSyg4rl0wsvuAK/Kq6ZXj0wY6U8fR9hob6WhzY04MvPPMTp8Vg8y3Y3N6IwdEoVtU5deWjs3F8tK0RH25tcBcE1ShkMYN2wW0Y3lBfi42r6zE4Oom3r83hqRffwFzCr/+V753F4OhiAbMPtzYAcCKSvKxu88lEjzpSZZP1OZifp9skJ8TGkmKfQAgxLoT4JyHEaSHEULHPRwrLrlALtnY2uVE5hUBF7Jwan8LXX3gdB4YmMDUbx74jYVe8HOtVor+3Ax9ta8R07DpOJZp266iFRg+/7Fpzi/OhlPijT30Q/b3tAJyN5YGj53E8HMWTPzyHcGQWLauWo+M2J8FqLhHeqV4V5ydnAAAvnZ/C3kMjGDg2hoFj5zE46vSX7etpxefub8OphPvo4PBlz2vff2IcA8fGMDg6iS9/96wbdbT30Eji/srEHOax70gYAIzPHdQ9KtTvhASXUlnyW6WUkyU6FykgmcofKHKxLNOVOdBfAYlQa4NvF8jyRB0bCJEQQeFGzag+stOx62545cCxMYxcmUFHwqWzvHZJkpuqa42zAfzr2ztx5I2Ia8kvr63Bzu611lLFXjjfdVAby+b36mqXuqUazM/Tuc9MaO0TgO4a4oN0vvp8fMd65qruy9bPt+9I
2FeykT4nVeZAdYhSNW7U2FOzcXzhmZ8AACam59DWVK9t4Ha4MfZ6GYKtnU1oXV2flCug5qf7+FUxM3OuKva/6Zb3AADu29DgbizbMmyd6J/FOH7dhaU2rzPdZ3UNZutCcnNRCpGXAH4ohJAA9kkpB0pwTlJkbIk8QG7hnmaMvddnfud08nzU7Tvb39ue0sXqwNAEBkcnsbm9EV1rVwJSYs3Kf8bgaBRbOpowHYvj2EgEu7dsxL+7twXAcFIdGd2a1pOn1NPI2OQpPNy9LimrVkXdtDY65SE2bWxMitQx9zy8QlL1BcvvfdWfCrh5e/NRCpH/qJTyLSHEbQBeFEK8LqU8qj4UQuwGsBsAbr/99hJMhxQCWyJPrgJi1qg3qzX6DSnUs3rHJk9hPBpLSmyyzV2J5+4tG7GsZokbcaOSqvTer7YSEXpJhid2dOHilLOZu/fQSFIlSlUeou8jrfjzwTHMJbJa/RYjM+v5+LHI1X3VnwrIzUfRRV5K+Vbi9YoQ4jsA7gVwVPt8AMAAAIRCIWkdhFQcxYrht1Vr9FuG1xRah8X/pUx/NrAo+NGZuBORc9sEOm5bgVh8AR3Nt7gRLUosn3rxHOauO3H5O7vXJYV+tjWtwIE9PW59HF1U1WLxxWdPY3B0EoOjk2hc8R63UQowjA+3NliTrJSrp6P5FgwcPW9tK0j/O/GiqCIvhKgHsERK+U7i378A4HeKeU5S3ZjVGs3s0m1dzXg2UWpg/4nxlPozB4cvIxyZxeb21fjQHbeir2eDK4Aqnv2Zly9iPBpzi5YBwL/MOxE1QxeiePXiNXy0rREDR89j+bIa9xx67ZvHH7oLbU0rUqpw2oqrKVSJBDU3da0q4er6gsTg6CSeefki/vCRbneT94fDlzEejSE+f8PTVZNuT4Sx9jc3xbbkmwF8RwihzvVXUsq/L/I5SRVjVms0s0uViDvZsan1Z2zVN9Vm5e4tG9HWVO9a+a9cuArAiUvf3N4IAFhWUwPAiXz52J1NiM6863af6uvZ4PR9jd9ALL6AcGTGdSmpc5mWtFejcLUBrLcvjM68i8HRSYxHY/iNv/mp9jQCrF35c7g4FXNj8U2L3xaBo2/a6scUCj49VAdFFXkp5XkAHyjmOUjlkqsIeAmjelUt+3Z2r0spQ2BG5uiblY4wz7qCPjg6iQ/dcaubdXtw+HLSXkBDfS3++/d/BgCYjl3HgaEJ9PVscC3jM5euJurtLLj/BuzuJf0abJ8pdxAg3DLKypJfXrsU3/3Jm7gw5Yj/I6GWlMXNdt26S8tWMC6f35Ft/qQyYQglKRq5ikC6uupelRptYmVazuHIDM5cuprwgTsRLzu717klDZTrRXfBLK91LPvDr1/BxPRckotHb0eoh4CqzFW1EKk5mIJvq1HfuKIWf/Qpp0aOHq55/51N+I2/+Sn+4N9+AK2r65Ouz8S2metVPkG5sNT72ZDpCYGWfmVAkScFw/yjztQE24t0rgfzc4VXSQJduJTve9NGJyM1XXcrRV9Pa5KVDogUl5Je0bKhvhZPvXhOa1+4NMWVpCd/6fH/6RbFD96xCoe++ID7sxJq20at2oAGhlOSpswIIb1Pb7Zk2nynpV8ZUORJwTD/qHVRTSekJjbXgxrT/FyhlyTwsiBti4PXQqLP5clHurH/xLibrWo29NCjb/YdCWPu+g0ATqtEc3xb8pdj+Y+5UTs2wbXNMZ2ImvfenKMu/sWysou1F0CygyJPCobNf66/2sj0SO81Rur3FmvBp6vsqIuh10Ki3C3KH97X04rHtt/pHmcLYQTsCUtmaWLdTaOStFRWL4Aky95sfp5pk9nPfbM98RQL1rqvDCjypGCYf9R+/sgzPdJ7bRqa31P1apZrG7Gma8LrHEp4Vds9ve494JQ+3tbVjB+9fhn3bWiwup/0Cpe6n1+fx2IY5gi2djqZtYuZs4uJTjZ/uU20091f78+cxXBofMp3i0H61qsbijwpK36sfT9W7M7utThz
6Sp2dq9N65qwYW7m7gq1IDobx+mL0+i+fZVbsuClsWkAjivEdD/pnaj0z23hjIuJXosNRvTY+sWooI6kJ4JsE6Bsn/f1bMCZS9cyZthmuv+keqDIk7Lix9rXNyyfevENABJ9PRuSvvfc6bdw+FwE96x/K6WhCLAoTun89cqaB4Dly2rw8vg0ultWYf+JcbSsqsN9G1bhA+tXWf3427qacWwkgq41K935mnNQlvwTO7rcRC89Y9Z2venKFmcSXzOaR0X8PLGjC/esfwvRGSeD16zt43X/GUVTnVDkiUul/rEqC1gXzNRG18n9WVV2qdmlyRS+A0MT+HBrA/74RyNoWbUcT5+8mHTu4bevuY1LtnY2Yc8DbVY/u6p3s6Wjyb13+0+Muxmuul9dj1+3bUjr1Tm9RHxqNo5YfB79vR1pQyn161QuqLrapairrUn6Od1Cyyia6oYiT1zK+ceaqVH4/hNjOHwugvs2NGDTxtTa8np/1qnZOKIz72Jz+2o3Jl6FGprC9/UXXnezYFtWLQcAHBi6hG3va3br0zx3+i23n6yajy6wXguKWnA+dMetaKivRWfzLVi+bIlvV4lpQZsbsXoJZj02X1nmujjrSWRqPPPnXGEUTWVDkScu5fxjtS0wi/XQF/DKBccf/oH1t6KudvF/W1vkzL4jYQwcGwMA191hS65S17loydfh6ZMX8Na1f8HTJy+49WmU+8dLYM2kq8W69uuSFp5fP3Aac9dvoKF+WdI9TrfAbetqTlro9h4a1UosLLguF9UYBVi0zM1xzWgaP9E15hi2uTKKprKhyBOXcv6x2hKnFi3OebfV3vLEBikAayan+p4eraKfQ/fp69f7zc/ci6nZOM5dfgcvjU3hvg3JTwumVaxe/cavO60Nr6Ohfhn+/NMfTqlvbzb3sFXjBAQAYO76DTgdr4CBY85mb39vO3Zv2Yjht6+59zBdxU7An3vOvBa6ZqoPijypCGyJU8n10BdDI/Xa6F6hhY9t77QmIdn80Ppxf/IrH/IUPq+nhkyRP7r/XG8kolDHRWfj2HvIeXLp62l1iqFdv+GWPr7/ziacuXQVkDJRA38DNrevRteaW9yaOoOjUTe6R4/ksbmH/Ah2LrkPpLKgyJOKIJ14mE8YfmPxbd2rorNxvPbmtaQnBq9aOV4tDs3CX+a8bV2d9h4aTVTOXMSscfPl755NfCITC4HAwNHz2Ny+GoOjk3j+zFsIR2Zxz/pb0d/bjlcuXMXg6CS2dKx25xKLzyOWyPoFgHvWr8Q962+13lc/gp1L7gOpLJaUewKEAIvi4TeqR8XAKzGzsSvUklSXpaG+Fo31tTgejiaFJm7rasbWzqaUJCcl6so1sivUgq2dTa5l7DVvc26276nx9x4awd5Do/jyd8/ieDjqJlPtOxJ2yyhsXF3nbg5v7WxCX08r6mqXYnB0Mmmz14n3X4q9h0ZwYGgikd066n5mzq0Y95xUHrTkSVXix9VgWp2O2yS18qNXjR3T0lV1bJR173duXt/TI17mrt/A8XAUodYGHBy+nJQMpUokb+1ssj5BeNXncTpUAf8YngSApAbl6phsQmbTlYuopJBbkgxFnlQlNldDJtHRSwoowpEZt2m3vpGq15UHkpt06O4cVVgMUrp1brw2Zs3FSI94mZqNY/myGgAypXiY2aNVzcWrDIR6X89ufXl8OmUvI9tNVPOe57sJy0WiNFDkSVVis9JtGa46to1IvWm3nuSUHNUCq5gpX7uOKl6WbmPW63psdfLNa003nuoF+8SOLrQ1rdAqaI5hLn4D0cSipCKLst1ENe95vpuwjNQpDRR5EghUQlJbU71n/Xqb20TVXVdJU3pJAbPkgK38gIqAUZa8ab3roaG2p4TU0grzKa0FM5VLVnzt+WG3Js43P3Ove82Pbe/EUy++kRJZZNb4ydai9rMJm85aZ6ROaaDIk0CgW+m2AmIKU5jamla4gmh+bvrnbQJta9qtW9vAYnMS1Wv22Mgk
BkcdP7kpksq98nevXsLE9ByiM3H89r9+n3V+poA+saML1xfOouO2WyyLhJN9+9G2xhRRVeWHj41E3K5UhSKdtc5IndJAkSeBwO+maCa8LE9bcpJNbAGkbO6qwmeqsUnXmluwpWN1SrmCWHzBfRpRTbx/emnazWjVY+x191R05l2MXJnBEzu6sKWjCV9/4XWMXHknaZ562YdUEXcWgMHRqO/KlOnulY7+ZOT1tEDffHGhyJPAkK1lmE23JS83ztRsHF945lUMjkbdjNW9h0aSmn8oX3t/bwf6ezug14vZdyTsWviq2ci2rmb89csTGH77GrrWrNRaCda4TxR6X9mhC9N49eJVxOJn8Ce/ErLOM929cXrWCqhNX5sYZ9uZyvxOumPpmy8uFHlSNsphweXabcnLjaOyTAHg5PkpfGD9rW7hMt0Hr8ZU51T1d1Tm6tbOpqRmI8pFMzUbx/LaJZi7fgOx+LzrWtncvtotoPbpv3gJADAXv5F0P02XldeegIryCUdmsPvpIfcpwtxkzqYzlfmddMfSN19cKPKkbJTDgsskPJmeBmxNQGLxBQyNT+F4OIqXxqZcK95PyQMArptGj9PXz/PY9k6tkUi7m1i1pWM1Dg5fxptX/yUxknQLuinR1qNt1LU/OzRhFfKvPT+McGQWDfXLUjav/YSF2u6N+k5uXaxIIaDIk7JRKgtOFx+/wuOF2SNVWcEqZl4vipZJGMORGZw8H8WjD3YkxbAD9pIM+pj6HsCxkQgGR6NYXlsDAJiejeMz33wZ1+au49WLVxGfP4tvfXZT0uZ0allkJ9Lo4pRjyZub1/q9UpE6Zp9YvWqoaplIyg9FnpSNUllwXrVp0uHtSnI2KU+en8RTL8qEP9s5h9lhKVN8+2LI42IVTFvde3MsNYaa41cffj8ODl923TBK9Fcud/6871630h3DiZt3mpTvPzGOnd1rXddNW9MKHNjT4163vnAlF1aTxiuS5hqLz/vapOYma2mgyJPAYwpmLiV2FXoW6Utj00m+dXWsbXzbU4sZo6/qxR99I4Lf+eX3Z7wu2xzb7l+BbV3Nrm99a2cT9hgCC0i33v6ZS1eThNhcmPRNXz2T1kEkNQM3q4bacg1YGqH0UORJ1ZCrINiqQmZbYlc/t24Nx+IL2Nm9NulY2/i2pxYzRl/Viz8ejiZZ+X7nqI+rW+SqNIIKuXSifNoBCOzsXmvtM6vGtdXlV4XQnA3kmpT5eW1S2+bs53fBhSA/KPKkaijURq2fvYBMC8Nj2+90XTBK6JS7RY+myYa+ntbEv5zSw17i6zXHTPNXvnizpr2+2Wu2ENSTvbwyebPBzLL1c68YYpkfFHlSNRRqozaXvQDbJqopUPmKkdmizytrV8evlavH+eubtnpIZXTmXdeNY9bPMWsDAbBW7vRLNveKIZb5UXSRF0J8HMBeADUA/lxK+fvFPicJJuUMtfNTJKwcYuQllqb4q/mbJRf0DdLN7asBAPdtaEAsPp/kb9efBLZ1NeO502+iv7cjbSariVeUUybyrbFzs1NUkRdC1AD4PwC2A7gE4JQQ4jkp5XAxz0tIMfGTMFVsbE8TfjJMlZ9d30tQG6TKoo/F57H30GiSNa9fs6q++fhDd7m1783z2PDbgcvv94k/im3J3wtgVEp5HgCEEN8G8DAAijypWnIV80JuINoETy+TbFbV1Oeub5oq0dazZPXoGFsdfXOR01sO+nEZmXPyK9502+RGsUV+HYAJ7edLAO7TDxBC7AawGwBuv/32Ik+HkPKRrSWaqUyvKa56stOmjZd9ReWkiwTyckupz1VGbcdtKzBwbMwaaaPjtTjaxNt27cyMzY1ii7ywvJeUQSGlHAAwAAChUEhajickEGRriWYq06ss8jOXrrptAf1U4kxnlXvN1ya6Kszz+oJM6qWbLTbxLrZr5mYKyyy2yF8CoP/m1wN4q8jnJKQiycYSdfrRzrvFzmzolrsqEZyttWvrsKWLn+kKUiULdoVa
kpK5VGG1QuFnQcxHqNMtIkFbAJYUefxTADqEEBuEELUAPgnguSKfk5CKRkWJTM3GPY9Rm5t1tTWeQqMs93ysaHMuSvyczNhFdoVa8PhDd2EuUbJg/4kxN5krG4H3c+0AXPfTgaEJz2O95uoHdT22+5bPuJVIUS15KeW8EOLzAP4BTgjlN6SUrxXznIRUOjYrMl0Fx3Sk6xSVy1zSRQ7tCrXgC8/8JPGOzROb/flsLDZRmXfLKvjx5Xtdf7b+/aBt8BY9Tl5K+QMAPyj2eQipFmwiYopfLpuMufixMxVCM8cfHJ10s2a9yLevq7qO/t6OtE8pXlnJJ89H3T0K/X3A330J2gYvM14JKTE2ESmE9ejHsjVbFfq1fM3xc41n190w6Rqa+zmPPtdtXc2IxRewub0xaY/Cdl8yXWfQoMgTUkK8hKUQ1qOfQmz6ewBSNlMb6mtTauZnO8dMC5ZucduKsGVzL8yx+nvb8aE7kjN2vca7WZKrKPKElJBSCYsTnZPcUBxQ8fVOY/Gd3esAANGZd7H30AjenI5hYnoOLauWJ46W7ljpLG+TTCKt19HJVIQtE7ax9BaLTK6iyBNSUDK5AEolLE50zgj6eztSatioxuJKBJ968Q0AwNGRSYxHY2htrEN/b3tSQxS9baBp+WdLulLEfvAK81Rj5bJpHWQo8oQUkEyWeqmERS85kKnPrNpEnZ59F4Aj9GcuXUsaS28bqHrJ6mOmIxvfdz4NXRS53OMg++cp8oQUkHK6AGwWrqpDo8/HFEFl3e89dBH9vR1utyg9wUpPfFpVV5syZjoyibJXYTWzrg7g9MU9NjKJ3Vs2FPQeB9k/T5EnxCd+rL1yugD8dqSyYS5OZmmEg8OXk2riZHONut/cVirYFHb1aruerz0/jMHRSSyrESm/g3ys8SD75ynyhPik0q29fITK1ijca+xsxTRdwTOzfINZV0dtEqtIGbMvrk42TwzmvIPsn6fIE+KTcll7fkW1mELlp2mKiZ8sXr02vU14zU3i1L64i6T7/ZidrYIq6DYo8oT4pFzWnlcmp0mpNg/99nf1k8WbyZWTbaGydNm6auM4iC6ZdBS7QBkhJE92hVqwtbPJ3Qz1wlZYy29BMD+osZ47/RYOn4vg4PDltMdmqqIJLC6cqsOUeX3qcz8Ztuq7tmtWBcnSLZJBhZY8IRWOXic+XU/VdDVxYvF5NyImV5FbrCnTnrHypXLD9Pe2Z9Vo3GvMbOrh5LMBHUQo8oRUAZm6NenH6CzGyy/4avidDq+aMlOzcew/MQZAoK+nNcn/7nVeP+hzs5Va8HLTBDlSJhco8oRUEdkKWHK8fE3K97KJGEpXA0aVBNZbAMbi85i7fiOltIINU8TNjdLFhnKLzeO85n4zW+02KPKEBIxM9dP1zwtV/TIWnwcgklwmSvhtkTOpOOI9F5/HviNhxOLzKRulZgJWpmiadE8oQc5wNaHIE1JF+LG8Mx1jfp6L1WuK5GPbO5M+twm/7fuq6NnO7nWoq13qunf0WvJKhP24p3K9B0GGIk9IFeHH8s50TCGsdz8i6bXRq7tizHLDulspHwu7FPcgH0r5JCGklJmPKhGhUEgODQ2VexqEkAzolvhzp9+E2nQFkGjdt4C9h0Zca1wXNLV5vLWzCU/s6EopX+ynsUm1o+7B4w/dVZAnCSHEK1LKkO0zWvKEkKzRo330TVcAKWGW6frINtTXppQbtjU2Ud8tJqW0rkv5JEGRJ4TkjJfvXRdKU9D8NhWxbbIWk1L66UsZAUSRJ+QmotDWqtp0TVdaIB9BK6UYlttPXyxY1oCQmwhb6QMvsimJkM24mc6R61jZztnETwmFaoSWPCE3EdlYq9m4L3K1gvefGMPeQ6OIxefdMMx8LOqbKTTSLxR5Qm4isnF/ZCO2ubtVhPGan4smqC6XfGAIJSGkbNxMmafFJF0IJX3yhJCy4ccPXshyyfniNZdKmqMJRZ4QUnR0EfQjiOqYcGQGX3z2dM4b
sYXGq2Z/Jc3RhD55QkjRyTbBSe+GVUkdnbxq9lfSHE2KJvJCiP8G4LMAVK3Q35ZS/qBY5yOEVC7ZJjjpbQE3bbxccp+9115Bupr9lbqvULSN14TIz0gp/5ff73DjlRBSbPxs9ha6tkyxYe0aQghJ4CeWPkihmMUW+c8LIT4NYAjAF6WU0+YBQojdAHYDwO23317k6RBCbnb8CHiQukvl5a4RQhwE8F7LR18CcBLAJJyWL78LYI2U8lfTjUd3DSGEZE/R3DVSym0+J/BnAJ7P51yEEFIo/Pjlg5KoVbQ4eSHEGu3HTwA4W6xzEUJINvgpgpZPobRKopg++f8phOiG464ZB/C5Ip6LEEJ84+WXL3ST80qgaCIvpfyPxRqbEELywWtjtRBNzjNRajcQQygJISRBKaz3UpdDpsgTQkiCUoROltoNRJEnhJASUuoYfFahJISQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQAEORJ4SQMjM1G8e+I2FMzcYLPnZeIi+E2CWEeE0IcUMIETI+e1wIMSqEOCeE+MX8pkkIIcHlwNAEvv7C6zgwNFHwsZfm+f2zAP4NgH36m0KILgCfBHA3gLUADgoh7pRSLuR5PkIICRy7Qi1Jr4UkL5GXUv4MAIQQ5kcPA/i2lPJdAGNCiFEA9wL4x3zORwghQaShvhafu7+tKGMXyye/DoD+3HEp8V4KQojdQoghIcRQJBIp0nQIIeTmJKMlL4Q4COC9lo++JKX8ntfXLO9J24FSygEAAwAQCoWsxxBCCMmNjCIvpdyWw7iXAOjOpfUA3sphHEIIIXlQLHfNcwA+KYR4jxBiA4AOAC8X6VyEEEI8yDeE8hNCiEsAPgLg+0KIfwAAKeVrAJ4FMAzg7wH8Z0bWEEJI6ck3uuY7AL7j8dnvAfi9fMYnhBCSH8x4JYSQACOkrJyAFiFEBMAFj49XA5gs4XRKRRCvK4jXBATzunhN1UO667pDStlk+6CiRD4dQoghKWUo85HVRRCvK4jXBATzunhN1UOu10V3DSGEBBiKPCGEBJhqEvmBck+gSATxuoJ4TUAwr4vXVD3kdF1V45MnhBCSPdVkyRNCCMkSijwhhASYqhJ5IcTvCiHOCCFOCyF+KIRYW+45FQIhxB8IIV5PXNt3hBC3lntO+ZKua1i1IYT4eKLD2agQ4rfKPZ9CIIT4hhDiihDibLnnUiiEEC1CiMNCiJ8l/t/rL/ec8kUI8XNCiJeFED9NXNNXsx6jmnzyQoifl1L+c+LfXwDQJaXcU+Zp5Y0Q4hcA/EhKOS+E+B8AIKX8zTJPKy+EEO8DcANO17D/KqUcKvOUckIIUQPgDQDb4VRXPQXgU1LK4bJOLE+EEB8DMAPgaSnl+8s9n0IghFgDYI2U8lUhxC0AXgHwy9X8uxJOR6Z6KeWMEGIZgEEA/VLKk37HqCpLXgl8gnp41KivNqSUP5RSzid+PAmnNHNVI6X8mZTyXLnnUQDuBTAqpTwvpYwD+DaczmdVjZTyKICpcs+jkEgp35ZSvpr49zsAfgaPZkXVgnSYSfy4LPFfVrpXVSIPAEKI3xNCTAD4DwC+XO75FIFfBfBCuSdBXHx3OSOVgxCiFcC/AvBSmaeSN0KIGiHEaQBXALwopczqmipO5IUQB4UQZy3/PQwAUsovSSlbAHwLwOfLO1v/ZLquxDFfAjAP59oqHj/XFAB8dzkjlYEQYgWAvwXwX4yn/6pESrkgpeyG84R/rxAiK/daXqWGi0EWnaj+CsD3AXyliNMpGJmuSwjRB2AHgF5ZJRslOXYNqzbY5ayKSPit/xbAt6SUf1fu+RQSKeVVIcSPAXwcgO8N84qz5NMh
hOjQftwJ4PVyzaWQCCE+DuA3AeyUUsbKPR+SxCkAHUKIDUKIWgCfhNP5jFQYiU3KvwDwMynlH5Z7PoVACNGkou2EEMsBbEOWuldt0TV/C6ATTtTGBQB7pJRvlndW+SOEGAXwHgDRxFsnqz1qSAjxCQB/DKAJwFUAp6WUv1jWSeWIEOKXAPxvADUAvpFoiFPVCCGeAfAAnPK1lwF8RUr5F2WdVJ4IITYDOAbgn+BoBAD8tpTyB+WbVX4IIe4BsB/O/3tLADwrpfydrMaoJpEnhBCSHVXlriGEEJIdFHlCCAkwFHlCCAkwFHlCCAkwFHlCCAkwFHlCCAkwFHlCCAkw/x/PxfL1Rsa9kQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
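   ],
   "source": [
    "import os\n",
    "# Workaround: let the process continue if more than one OpenMP runtime gets loaded\n",
    "# (avoids the 'OMP: Error #15' crash that some local environments hit).\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure()\n",
    "# Plot the second feature against the labels; s=1 keeps the markers small\n",
    "plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), s=1)\n",
    "plt.show()"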
   ]
  },
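  {
   "cell_type": "markdown",
   "id": "255c1492",
   "metadata": {},
   "source": [
    "### 3.2.2 Reading the Dataset\n",
    "\n",
    "Training a model means making repeated passes over the dataset, drawing one minibatch of examples at a time, and using it to update the model. Because this process is fundamental to training machine learning algorithms, it is worth defining a function that shuffles the examples in the dataset and serves them in minibatches.\n",
    "\n",
    "In the code below, we define the `data_iter` function. It takes a batch size, a feature matrix, and a label vector as input and yields minibatches of size `batch_size`. Each minibatch consists of a set of features and the corresponding labels."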
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1aaf9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter(batch_size, features, labels):\n",
    "    num_examples = len(features)\n",
    "    indices = list(range(num_examples))\n",
    "    # The examples are read in random order, with no particular sequence\n",
    "    random.shuffle(indices)\n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        # The last minibatch may be smaller if batch_size does not divide num_examples\n",
    "        batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])\n",
    "        yield features[batch_indices], labels[batch_indices]"
   ]
  },
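  {
   "cell_type": "markdown",
   "id": "demo-iter-md",
   "metadata": {},
   "source": [
    "As a small sanity check, the sketch below runs a copy of the same iterator on toy data to confirm that one pass visits every example exactly once and that the final minibatch is simply smaller when the batch size does not divide the number of examples. The names `demo_data_iter`, `demo_features`, and `demo_labels` are illustrative assumptions and are not part of the original code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-iter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "\n",
    "def demo_data_iter(batch_size, features, labels):\n",
    "    # Same logic as data_iter above, copied so that this cell runs on its own\n",
    "    num_examples = len(features)\n",
    "    indices = list(range(num_examples))\n",
    "    random.shuffle(indices)\n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])\n",
    "        yield features[batch_indices], labels[batch_indices]\n",
    "\n",
    "# Hypothetical toy data: 10 examples whose single feature equals the row index\n",
    "demo_features = torch.arange(10, dtype=torch.float32).reshape(-1, 1)\n",
    "demo_labels = 2 * demo_features\n",
    "\n",
    "# Batch sizes for one pass with batch_size=4: two full batches and one partial batch\n",
    "demo_batch_sizes = [len(X) for X, _ in demo_data_iter(4, demo_features, demo_labels)]\n",
    "# Collect every feature value seen in one pass\n",
    "demo_seen = torch.cat([X for X, _ in demo_data_iter(4, demo_features, demo_labels)])\n",
    "print(demo_batch_sizes)  # [4, 4, 2]\n",
    "print(sorted(demo_seen.flatten().tolist()))  # each of 0.0 .. 9.0 appears once"
   ]
  },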
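  {
   "cell_type": "markdown",
   "id": "0ccfcfad",
   "metadata": {},
   "source": [
    "In general, we use reasonably sized minibatches to take advantage of GPU hardware, which excels at parallel computation. Each example can be fed through the model in parallel, and the gradient of the loss for each example can also be computed in parallel, so a GPU can process a few hundred examples in little more time than it takes to process a single one.\n",
    "\n",
    "To build some intuition, let us read and print the first minibatch of examples. The shape of the features in each minibatch tells us both the batch size and the number of input features. Likewise, the minibatch of labels has shape `(batch_size, 1)`."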
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71364155",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.9150, -1.5341],\n",
      "        [-0.2922, -1.0456],\n",
      "        [ 1.3034,  1.1610],\n",
      "        [ 0.0903,  1.0224],\n",
      "        [ 0.9239, -0.0307],\n",
      "        [-0.2999,  0.6628],\n",
      "        [ 1.6242, -0.5583],\n",
      "        [ 0.1993, -0.2599],\n",
      "        [-0.1996,  0.1486],\n",
      "        [-1.9182, -2.3965]]) \n",
      " tensor([[11.2512],\n",
      "        [ 7.1768],\n",
      "        [ 2.8552],\n",
      "        [ 0.8987],\n",
      "        [ 6.1418],\n",
      "        [ 1.3456],\n",
      "        [ 9.3358],\n",
      "        [ 5.4837],\n",
      "        [ 3.2794],\n",
      "        [ 8.5036]])\n"
     ]
    }
   ],
   "source": [
    "batch_size = 10\n",
    "\n",
    "for X,y in data_iter(batch_size,features,labels):\n",
    "    print(X,'\\n',y)\n",
    "    break"
   ]
  },
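  {
   "cell_type": "markdown",
   "id": "demo-loader-md",
   "metadata": {},
   "source": [
    "For comparison, the same kind of shuffled minibatches can be drawn with PyTorch's own data utilities. The following is a minimal sketch using `TensorDataset` and `DataLoader` from `torch.utils.data`. The stand-in tensors `demo_X` and `demo_y` are assumptions chosen to match the shapes of `features` and `labels`, and they are kept separate so that the variables used in the rest of this section are not overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-loader-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Hypothetical stand-in data with the same shapes as features and labels\n",
    "demo_X = torch.normal(0, 1, (1000, 2))\n",
    "demo_y = torch.matmul(demo_X, torch.tensor([2, -3.4])) + 4.2\n",
    "demo_y = demo_y.reshape((-1, 1))\n",
    "\n",
    "# Pair each feature row with its label, then batch and shuffle\n",
    "demo_dataset = TensorDataset(demo_X, demo_y)\n",
    "demo_loader = DataLoader(demo_dataset, batch_size=10, shuffle=True)\n",
    "\n",
    "# Fetch the first minibatch\n",
    "demo_Xb, demo_yb = next(iter(demo_loader))\n",
    "print(demo_Xb.shape, demo_yb.shape)  # torch.Size([10, 2]) torch.Size([10, 1])\n",
    "print(len(demo_loader))  # 100 minibatches per pass"
   ]
  },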
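  {
   "cell_type": "markdown",
   "id": "693aa98f",
   "metadata": {},
   "source": [
    "As we run the iteration, we obtain distinct minibatches one after another until the entire dataset has been traversed. The iterator implemented above is fine for teaching purposes, but it is inefficient in ways that could cause trouble on real problems. For example, it requires that we load all the data into memory and that we perform a lot of random memory access. The built-in iterators implemented in deep learning frameworks are considerably more efficient, and they can handle both data stored in files and data fed through data streams.\n",
    "\n",
    "### 3.2.3 Initializing Model Parameters\n",
    "Before we can begin optimizing the model parameters with minibatch stochastic gradient descent, we need to have some parameters in the first place. In the code below, we initialize the weights by sampling random numbers from a normal distribution with mean 0 and standard deviation 0.01, and we set the bias to 0."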
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42f45b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "w = torch.normal(0,0.01,size=(2,1),requires_grad = True)\n",
    "b = torch.zeros(1,requires_grad=True)"
   ]
  },
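  {
   "cell_type": "markdown",
   "id": "d0dc3866",
   "metadata": {},
   "source": [
    "After initializing the parameters, our task is to update them until they fit the data sufficiently well. Each update requires the gradient of the loss function with respect to the model parameters. Given this gradient, we can update each parameter in the direction that reduces the loss. Because computing gradients by hand is tedious and error-prone, nobody does it by hand. Instead, we use the automatic differentiation introduced in Section 2.5 to compute the gradient.\n",
    "\n",
    "### 3.2.4 Defining the Model\n",
    "Next, we must define the model, relating its inputs and parameters to its outputs. Recall that to compute the output of the linear model, we simply take the matrix-vector product of the input features $X$ and the model weights $w$, and then add the bias $b$. Note that $Xw$ is a vector while $b$ is a scalar. Recall the broadcasting mechanism described in Section 2.1.3: when we add a vector and a scalar, the scalar is added to each component of the vector."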
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "59b335e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linreg(X, w, b):  #@save\n",
    "    '''The linear regression model.'''\n",
    "    return torch.matmul(X, w) + b"
   ]
  },
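  {
   "cell_type": "markdown",
   "id": "demo-broadcast-md",
   "metadata": {},
   "source": [
    "To make the broadcasting step concrete, here is a small sketch with hand-picked numbers. The tensors `demo_X`, `demo_w`, and `demo_b` are illustrative assumptions. The point is only that the single bias value is added to every row of the matrix-vector product."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-broadcast-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical inputs: 3 examples with 2 features each\n",
    "demo_X = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n",
    "demo_w = torch.tensor([[2.0], [-1.0]])  # shape (2, 1)\n",
    "demo_b = torch.tensor([0.5])            # shape (1,)\n",
    "\n",
    "demo_Xw = torch.matmul(demo_X, demo_w)  # shape (3, 1)\n",
    "demo_out = demo_Xw + demo_b             # the bias is broadcast across all 3 rows\n",
    "print(demo_Xw.flatten().tolist())   # [0.0, 2.0, 4.0]\n",
    "print(demo_out.flatten().tolist())  # [0.5, 2.5, 4.5]"
   ]
  },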
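  {
   "cell_type": "markdown",
   "id": "3776881b",
   "metadata": {},
   "source": [
    "### 3.2.5 Defining the Loss Function\n",
    "\n",
    "Since updating the model requires the gradient of the loss function, we should define the loss function first. Here we use the squared loss function described in Section 3.1. In the implementation, we need to reshape the true value `y` so that it has the same shape as the predicted value `y_hat`."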
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "93b5c596",
   "metadata": {},
   "outputs": [],
   "source": [
    "def squared_loss(y_hat, y):  #@save\n",
    "    '''Squared loss, returned per example (not averaged).'''\n",
    "    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2"
   ]
  },
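  {
   "cell_type": "markdown",
   "id": "demo-reshape-md",
   "metadata": {},
   "source": [
    "The reshape matters because of broadcasting. The sketch below uses made-up values (`demo_y_hat` and `demo_y` are assumptions) to show what can go wrong when a column of predictions with shape `(3, 1)` is compared directly against a flat label vector with shape `(3,)`: the subtraction silently produces a `(3, 3)` matrix instead of three per-example losses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-reshape-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "demo_y_hat = torch.tensor([[1.0], [2.0], [3.0]])  # predictions, shape (3, 1)\n",
    "demo_y = torch.tensor([1.0, 2.0, 5.0])            # labels, shape (3,)\n",
    "\n",
    "# Without reshaping, (3, 1) - (3,) broadcasts to a (3, 3) matrix of pairwise differences\n",
    "demo_wrong = (demo_y_hat - demo_y) ** 2 / 2\n",
    "# With reshaping, the subtraction is element-wise and the result stays (3, 1)\n",
    "demo_right = (demo_y_hat - demo_y.reshape(demo_y_hat.shape)) ** 2 / 2\n",
    "print(demo_wrong.shape, demo_right.shape)  # torch.Size([3, 3]) torch.Size([3, 1])\n",
    "print(demo_right.flatten().tolist())       # [0.0, 0.0, 2.0]"
   ]
  },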
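  {
   "cell_type": "markdown",
   "id": "d6779e27",
   "metadata": {},
   "source": [
    "### 3.2.6 Defining the Optimization Algorithm\n",
    "As discussed in Section 3.1, linear regression has a closed-form solution. None of the other models in this book do, however, so we take this opportunity to introduce minibatch stochastic gradient descent.\n",
    "\n",
    "At each step, we draw one minibatch at random from the dataset and compute the gradient of the loss with respect to the parameters. We then update the parameters in the direction that reduces the loss. The function below implements the minibatch stochastic gradient descent update. It takes the set of model parameters, the learning rate, and the batch size as input. The size of each update step is determined by the learning rate `lr`. Because the loss we compute is a sum over the examples in the minibatch, we normalize the step size by the batch size (`batch_size`), so that the step size does not depend on our choice of batch size."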
   ]
  },
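  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d68f7161",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size):  #@save\n",
    "    '''Minibatch stochastic gradient descent.'''\n",
    "    # The update itself must not be tracked by autograd\n",
    "    with torch.no_grad():\n",
    "        for param in params:\n",
    "            # The gradient holds a sum over the minibatch, so divide by batch_size\n",
    "            param -= lr * param.grad / batch_size\n",
    "            # Reset the gradient so it does not accumulate into the next step\n",
    "            param.grad.zero_()"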
   ]
  },
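  {
   "cell_type": "markdown",
   "id": "demo-sgd-md",
   "metadata": {},
   "source": [
    "To see the update rule at work, the sketch below takes a single step on a one-parameter problem small enough to verify by hand. The function `demo_sgd` is a copy of the update above, and `demo_w`, `demo_x`, and `demo_y` are made-up values; all of these names are assumptions for illustration. With $w = 1$, inputs $(1, 2)$, and targets $(3, 6)$, the summed gradient is $(-2)(1) + (-4)(2) = -10$, so a step with learning rate 0.1 and batch size 2 moves $w$ to $1 - 0.1 \\cdot (-10) / 2 = 1.5$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-sgd-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def demo_sgd(params, lr, batch_size):\n",
    "    # Same update rule as sgd above, copied so that this cell runs on its own\n",
    "    with torch.no_grad():\n",
    "        for param in params:\n",
    "            param -= lr * param.grad / batch_size\n",
    "            param.grad.zero_()\n",
    "\n",
    "# Hypothetical one-parameter model y_hat = w * x, with data generated by y = 3x\n",
    "demo_w = torch.tensor([1.0], requires_grad=True)\n",
    "demo_x = torch.tensor([1.0, 2.0])\n",
    "demo_y = torch.tensor([3.0, 6.0])\n",
    "\n",
    "# Summed squared loss over the minibatch of 2 examples\n",
    "demo_l = ((demo_w * demo_x - demo_y) ** 2 / 2).sum()\n",
    "demo_l.backward()\n",
    "demo_grad = demo_w.grad.item()  # read the gradient before the step resets it\n",
    "\n",
    "demo_sgd([demo_w], lr=0.1, batch_size=2)\n",
    "print(demo_grad, demo_w.item())  # gradient -10.0, updated weight 1.5"
   ]
  },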
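  {
   "cell_type": "markdown",
   "id": "b8486d76",
   "metadata": {},
   "source": [
    "### 3.2.7 Training\n",
    "\n",
    "Now that we have all the ingredients for model training in place, we can implement the main training loop. It is crucial to understand this code, because you will see nearly identical training loops over and over again in deep learning. In each iteration, we read a minibatch of training examples and pass them through the model to obtain a set of predictions. After computing the loss, we start the backward pass, which stores the gradient for each parameter. Finally, we call the optimization algorithm `sgd` to update the model parameters.\n",
    "\n",
    "In summary, we execute the following loop:\n",
    "\n",
    "- Initialize the parameters\n",
    "- Repeat until done:\n",
    "  - Compute the gradient $g \\gets \\partial_{(w,b)} \\frac{1}{|\\mathcal{B}|} \\sum_{i \\in \\mathcal{B}} l(x^{(i)}, y^{(i)}, w, b)$\n",
    "  - Update the parameters $(w, b) \\gets (w, b) - \\eta g$\n",
    "\n",
    "In each epoch, we use the `data_iter` function to iterate through the entire dataset, using every example in the training dataset once (assuming that the number of examples is divisible by the batch size). The number of epochs `num_epochs` and the learning rate `lr` are both hyperparameters, which we set here to 3 and 0.03, respectively. Setting hyperparameters is tricky and requires some adjustment by trial and error. We ignore these details for now and return to them in Chapter 11."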
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4485a029",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.03\n",
    "num_epochs = 3\n",
    "net = linreg\n",
    "loss = squared_loss\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter(batch_size, features, labels):\n",
    "        l = loss(net(X, w, b), y)  # minibatch loss on X and y\n",
    "        # l has shape (batch_size, 1) rather than being a scalar, so we sum its\n",
    "        # elements and compute the gradients with respect to [w, b] from that sum\n",
    "        l.sum().backward()\n",
    "        sgd([w, b], lr, batch_size)  # update the parameters using their gradients\n",
    "    # Report the mean training loss after each epoch, without tracking gradients\n",
    "    with torch.no_grad():\n",
    "        train_l = loss(net(features, w, b), labels)\n",
    "        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')"
   ]
  },
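  {
   "cell_type": "markdown",
   "id": "demo-lstsq-md",
   "metadata": {},
   "source": [
    "As a cross-check on what the training loop should converge to, the closed-form least-squares solution mentioned in Section 3.2.6 can be computed directly. The sketch below is not part of the from-scratch method: it regenerates a separate synthetic dataset with the same true parameters and solves it with `torch.linalg.lstsq`. The `demo_*` names, the private random generator, and the trick of appending a column of ones to estimate the bias are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-lstsq-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# A private generator keeps this sketch repeatable without touching the global RNG\n",
    "demo_gen = torch.Generator().manual_seed(0)\n",
    "demo_true_w = torch.tensor([2, -3.4])\n",
    "demo_true_b = 4.2\n",
    "\n",
    "# Hypothetical dataset built the same way as in Section 3.2.1\n",
    "demo_X = torch.randn(1000, 2, generator=demo_gen)\n",
    "demo_y = (torch.matmul(demo_X, demo_true_w) + demo_true_b).reshape((-1, 1))\n",
    "demo_y = demo_y + 0.01 * torch.randn(1000, 1, generator=demo_gen)\n",
    "\n",
    "# Append a column of ones so that the bias is estimated as a third weight\n",
    "demo_A = torch.cat([demo_X, torch.ones(len(demo_X), 1)], dim=1)\n",
    "demo_theta = torch.linalg.lstsq(demo_A, demo_y).solution.flatten()\n",
    "print(demo_theta)  # expected to be close to [2, -3.4, 4.2]"
   ]
  },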
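  {
   "cell_type": "markdown",
   "id": "a19123ca",
   "metadata": {},
   "source": [
    "Because we synthesized the dataset ourselves, we know exactly what the true parameters are. We can therefore evaluate how well training succeeded by comparing the true parameters with those learned through training. Indeed, the two turn out to be very close to each other."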
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "18a85ccb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w的估计误差: tensor([ 1.9737, -3.4000], grad_fn=<SubBackward0>)\n",
      "b的估计误差: tensor([4.2000], grad_fn=<RsubBackward1>)\n"
     ]
    }
   ],
   "source": [
    "print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
    "print(f'b的估计误差: {true_b - b}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}